package com.example.harmonet.harmtorch;

/**
 * Layer that reshapes its input into a single dimension.
 *
 * <p>The output length is the product of every extent in the input's shape. This layer holds no
 * parameters of its own.
 */
public class Flatten implements Layer {

    /**
     * Builds a one-dimensional tensor sized to the total element count of {@code in} and hands it
     * the input's data.
     *
     * <p>NOTE(review): the output receives {@code in.tensor()} via {@code setTensor}. Whether that
     * copies the data or shares the input's backing storage depends on {@code Tensor} and is not
     * visible here — confirm before mutating either tensor in place.
     *
     * @param in tensor to flatten
     * @return a new tensor whose single dimension is the product of {@code in.dim()}; that product
     *     is 1 when {@code in.dim()} is empty
     * @throws Exception propagated from {@code Tensor} construction or data transfer
     */
    @Override
    public Tensor forward(Tensor in) throws Exception {
        // Multiplication is commutative/associative (even with int wraparound), so walking the
        // axes from last to first yields exactly the same product as a forward walk.
        int elementCount = 1;
        for (int axis = in.dim().length - 1; axis >= 0; axis--) {
            elementCount *= in.dim()[axis];
        }

        final Tensor flattened = new Tensor(elementCount);
        flattened.setTensor(in.tensor());
        return flattened;
    }

    /**
     * Reports this layer's parameter count.
     *
     * @return always {@code 0}; presumably this is the number of trainable parameters the layer
     *     expects in {@link #init(float[])} — confirm against the {@code Layer} contract
     */
    @Override
    public int getParam() {
        return 0;
    }

    /**
     * Parameter initialization hook; intentionally a no-op because this layer stores no state.
     *
     * @param param ignored
     */
    @Override
    public void init(float[] param) {
        // Nothing to load: the layer has no parameters.
    }
}
